import os

from transformers import LlamaTokenizer, LlamaForCausalLM

# Local path of the converted HF-format Ziya-LLaMA-13B checkpoint.
ckpt_path = '/home/zhangsenzhen/2023Q2/checkpoint_download/Ziya-LLaMA-13B-v1.1-hf/'

# device_map='auto' shards the model across available GPUs (requires accelerate).
model = LlamaForCausalLM.from_pretrained(ckpt_path, device_map='auto')
# Load the tokenizer from the checkpoint directory so its saved config
# (special tokens, added tokens) is applied. NOTE(review): the original passed
# use_fast=False to the LlamaTokenizer constructor, but that keyword belongs to
# AutoTokenizer.from_pretrained and was silently ignored there; LlamaTokenizer
# is already the slow (sentencepiece) implementation.
tokenizer = LlamaTokenizer.from_pretrained(ckpt_path)

# Interactive chat loop: read a query, wrap it in the Ziya prompt template,
# greedily generate up to 512 new tokens, and print only the model's reply.
# Exit cleanly on Ctrl-D (EOF) or Ctrl-C.
while True:
    print("inputs:", end='')
    try:
        query = input()
    except (EOFError, KeyboardInterrupt):
        print()
        break
    # Ziya's instruction-tuning prompt format.
    inputs = '<human>:' + query.strip() + '\n<bot>:'

    # Move the prompt tensors to the model's device; with device_map='auto'
    # the embedding layer may live on a GPU, and CPU inputs would fail there.
    input_ids = tokenizer(inputs, return_tensors="pt").input_ids.to(model.device)
    generate_ids = model.generate(
        input_ids,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding (top_k would be ignored here)
        eos_token_id=2,   # </s>
        bos_token_id=1,   # <s>
        pad_token_id=0)
    # Decode only the newly generated tokens so the prompt is not echoed,
    # and drop <s>/</s> markers from the printed reply.
    new_tokens = generate_ids[:, input_ids.shape[1]:]
    output = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
    print(output)

